//===-- NVPTXInstPrinter.cpp - PTX assembly instruction printing ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Print MCInst instructions to .ptx format.
//
//===----------------------------------------------------------------------===//

#include "MCTargetDesc/NVPTXInstPrinter.h"
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTX.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCInstrInfo.h"
#include "llvm/MC/MCSubtargetInfo.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormattedStream.h"
#include <cctype>
using namespace llvm;

#define DEBUG_TYPE "asm-printer"

#include "NVPTXGenAsmWriter.inc"

NVPTXInstPrinter::NVPTXInstPrinter(const MCAsmInfo &MAI, const MCInstrInfo &MII,
                                   const MCRegisterInfo &MRI)
    : MCInstPrinter(MAI, MII, MRI) {}

void NVPTXInstPrinter::printRegName(raw_ostream &OS, MCRegister Reg) const {
  // Decode the virtual register
  // Must be kept in sync with NVPTXAsmPrinter::encodeVirtualRegister
  unsigned RCId = (Reg.id() >> 28);
  switch (RCId) {
  default: report_fatal_error("Bad virtual register encoding");
  case 0:
    // This is actually a physical register, so defer to the autogenerated
    // register printer
    OS << getRegisterName(Reg);
    return;
  case 1:
    OS << "%p";
    break;
  case 2:
    OS << "%rs";
    break;
  case 3:
    OS << "%r";
    break;
  case 4:
    OS << "%rd";
    break;
  case 5:
    OS << "%f";
    break;
  case 6:
    OS << "%fd";
    break;
  case 7:
    OS << "%h";
    break;
  case 8:
    OS << "%hh";
    break;
  }

  unsigned VReg = Reg.id() & 0x0FFFFFFF;
  OS << VReg;
}

void NVPTXInstPrinter::printInst(const MCInst *MI, uint64_t Address,
                                 StringRef Annot, const MCSubtargetInfo &STI,
                                 raw_ostream &OS) {
  printInstruction(MI, Address, OS);

  // Next always print the annotation.
  printAnnotation(OS, Annot);
}

void NVPTXInstPrinter::printOperand(const MCInst *MI, unsigned OpNo,
                                    raw_ostream &O) {
  const MCOperand &Op = MI->getOperand(OpNo);
  if (Op.isReg()) {
    unsigned Reg = Op.getReg();
    printRegName(O, Reg);
  } else if (Op.isImm()) {
    O << markup("<imm:") << formatImm(Op.getImm()) << markup(">");
  } else {
    assert(Op.isExpr() && "Unknown operand kind in printOperand");
    Op.getExpr()->print(O, &MAI);
  }
}

void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
                                    const char *Modifier) {
  const MCOperand &MO = MI->getOperand(OpNum);
  int64_t Imm = MO.getImm();

  if (strcmp(Modifier, "ftz") == 0) {
    // FTZ flag
    if (Imm & NVPTX::PTXCvtMode::FTZ_FLAG)
      O << ".ftz";
  } else if (strcmp(Modifier, "sat") == 0) {
    // SAT flag
    if (Imm & NVPTX::PTXCvtMode::SAT_FLAG)
      O << ".sat";
  } else if (strcmp(Modifier, "relu") == 0) {
    // RELU flag
    if (Imm & NVPTX::PTXCvtMode::RELU_FLAG)
      O << ".relu";
  } else if (strcmp(Modifier, "base") == 0) {
    // Default operand
    switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) {
    default:
      return;
    case NVPTX::PTXCvtMode::NONE:
      break;
    case NVPTX::PTXCvtMode::RNI:
      O << ".rni";
      break;
    case NVPTX::PTXCvtMode::RZI:
      O << ".rzi";
      break;
    case NVPTX::PTXCvtMode::RMI:
      O << ".rmi";
      break;
    case NVPTX::PTXCvtMode::RPI:
      O << ".rpi";
      break;
    case NVPTX::PTXCvtMode::RN:
      O << ".rn";
      break;
    case NVPTX::PTXCvtMode::RZ:
      O << ".rz";
      break;
    case NVPTX::PTXCvtMode::RM:
      O << ".rm";
      break;
    case NVPTX::PTXCvtMode::RP:
      O << ".rp";
      break;
    case NVPTX::PTXCvtMode::RNA:
      O << ".rna";
      break;
    }
  } else {
    llvm_unreachable("Invalid conversion modifier");
  }
}

void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O,
                                    const char *Modifier) {
  const MCOperand &MO = MI->getOperand(OpNum);
  int64_t Imm = MO.getImm();

  if (strcmp(Modifier, "ftz") == 0) {
    // FTZ flag
    if (Imm & NVPTX::PTXCmpMode::FTZ_FLAG)
      O << ".ftz";
  } else if (strcmp(Modifier, "base") == 0) {
    switch (Imm & NVPTX::PTXCmpMode::BASE_MASK) {
    default:
      return;
    case NVPTX::PTXCmpMode::EQ:
      O << ".eq";
      break;
    case NVPTX::PTXCmpMode::NE:
      O << ".ne";
      break;
    case NVPTX::PTXCmpMode::LT:
      O << ".lt";
      break;
    case NVPTX::PTXCmpMode::LE:
      O << ".le";
      break;
    case NVPTX::PTXCmpMode::GT:
      O << ".gt";
      break;
    case NVPTX::PTXCmpMode::GE:
      O << ".ge";
      break;
    case NVPTX::PTXCmpMode::LO:
      O << ".lo";
      break;
    case NVPTX::PTXCmpMode::LS:
      O << ".ls";
      break;
    case NVPTX::PTXCmpMode::HI:
      O << ".hi";
      break;
    case NVPTX::PTXCmpMode::HS:
      O << ".hs";
      break;
    case NVPTX::PTXCmpMode::EQU:
      O << ".equ";
      break;
    case NVPTX::PTXCmpMode::NEU:
      O << ".neu";
      break;
    case NVPTX::PTXCmpMode::LTU:
      O << ".ltu";
      break;
    case NVPTX::PTXCmpMode::LEU:
      O << ".leu";
      break;
    case NVPTX::PTXCmpMode::GTU:
      O << ".gtu";
      break;
    case NVPTX::PTXCmpMode::GEU:
      O << ".geu";
      break;
    case NVPTX::PTXCmpMode::NUM:
      O << ".num";
      break;
    case NVPTX::PTXCmpMode::NotANumber:
      O << ".nan";
      break;
    }
  } else {
    llvm_unreachable("Empty Modifier");
  }
}

void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum,
                                     raw_ostream &O, const char *Modifier) {
  if (Modifier) {
    const MCOperand &MO = MI->getOperand(OpNum);
    int Imm = (int) MO.getImm();
    if (!strcmp(Modifier, "volatile")) {
      if (Imm)
        O << ".volatile";
    } else if (!strcmp(Modifier, "addsp")) {
      switch (Imm) {
      case NVPTX::PTXLdStInstCode::GLOBAL:
        O << ".global";
        break;
      case NVPTX::PTXLdStInstCode::SHARED:
        O << ".shared";
        break;
      case NVPTX::PTXLdStInstCode::LOCAL:
        O << ".local";
        break;
      case NVPTX::PTXLdStInstCode::PARAM:
        O << ".param";
        break;
      case NVPTX::PTXLdStInstCode::CONSTANT:
        O << ".const";
        break;
      case NVPTX::PTXLdStInstCode::GENERIC:
        break;
      default:
        llvm_unreachable("Wrong Address Space");
      }
    } else if (!strcmp(Modifier, "sign")) {
      if (Imm == NVPTX::PTXLdStInstCode::Signed)
        O << "s";
      else if (Imm == NVPTX::PTXLdStInstCode::Unsigned)
        O << "u";
      else if (Imm == NVPTX::PTXLdStInstCode::Untyped)
        O << "b";
      else if (Imm == NVPTX::PTXLdStInstCode::Float)
        O << "f";
      else
        llvm_unreachable("Unknown register type");
    } else if (!strcmp(Modifier, "vec")) {
      if (Imm == NVPTX::PTXLdStInstCode::V2)
        O << ".v2";
      else if (Imm == NVPTX::PTXLdStInstCode::V4)
        O << ".v4";
    } else
      llvm_unreachable("Unknown Modifier");
  } else
    llvm_unreachable("Empty Modifier");
}

void NVPTXInstPrinter::printMmaCode(const MCInst *MI, int OpNum, raw_ostream &O,
                                    const char *Modifier) {
  const MCOperand &MO = MI->getOperand(OpNum);
  int Imm = (int)MO.getImm();
  if (Modifier == nullptr || strcmp(Modifier, "version") == 0) {
    O << Imm; // Just print out PTX version
  } else if (strcmp(Modifier, "aligned") == 0) {
    // PTX63 requires '.aligned' in the name of the instruction.
    if (Imm >= 63)
      O << ".aligned";
  } else
    llvm_unreachable("Unknown Modifier");
}

void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
                                       raw_ostream &O, const char *Modifier) {
  printOperand(MI, OpNum, O);

  if (Modifier && !strcmp(Modifier, "add")) {
    O << ", ";
    printOperand(MI, OpNum + 1, O);
  } else {
    if (MI->getOperand(OpNum + 1).isImm() &&
        MI->getOperand(OpNum + 1).getImm() == 0)
      return; // don't print ',0' or '+0'
    O << "+";
    printOperand(MI, OpNum + 1, O);
  }
}

void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum,
                                       raw_ostream &O, const char *Modifier) {
  const MCOperand &Op = MI->getOperand(OpNum);
  assert(Op.isExpr() && "Call prototype is not an MCExpr?");
  const MCExpr *Expr = Op.getExpr();
  const MCSymbol &Sym = cast<MCSymbolRefExpr>(Expr)->getSymbol();
  O << Sym.getName();
}
